查看原文
其他

机器学习预测乳腺肿瘤性质(5)

2017-12-13 汪君 Python爱好者社区

作者:汪君,专职数据分析,Python和R爱好者

个人微信公众号:学英文和玩数据


前文传送门:

机器学习预测乳腺肿瘤性质(1)

机器学习预测乳腺肿瘤性质(2)

机器学习预测乳腺肿瘤性质(3)——贝叶斯分类器

机器学习预测乳腺肿瘤性质(4)——神经网络

前面几次的交叉验证,我们都利用scikit-learn里的accuracy指标来评价分类器在测试集上的性能。




但是accuracy作为分类器评价指标存在一定局限性,假如我们面临一个二分类问题,测试数据集A(size=100)里有90个样例分类标签是0,还有10个实际分类标签是1。我们构建一个不借助任何机器学习算法的分类器(暂取名zero classifier),无论什么测试数据,zero classifier一直预测分类标签是0。这样的分类器在前述测试数据集A上的表现如何?它也可以达到90%的准确率,与前面人工神经网络的预测水平相当,这说明accuracy在面对比较skewed的数据集时作分类器的评价指标会存在问题。

其实机器学习领域并非只有accuracy一个评价指标,常用其他指标还有confusion matrix,precision,recall,F1-score,ROC curve等,通过乳腺肿瘤的数据集,我们来对这些指标逐一解读。

一、confusion matrix (混淆矩阵)

首先看confusion matrix,中文翻译为混淆矩阵,sklearn里的函数confusion_matrix可以为我们返回混淆矩阵。混淆矩阵的大小为n*n,n表示所有类别的数量,每行表示实际类别数量,每列表示预测出来的类别数量,通过矩阵的形式可以看到分类器在每个类别的预测成绩。

还是以二分类(0,1)问题举例,下面这个混淆矩阵的解读方法:实际分类为0的数据有50个,其中分类器正确预测了其中48个,把2个错分到类别1,实际分类为1的数据是58个,分类器正确预测了51个,把7个错分到了0类别。在混淆矩阵中,我们希望主对角线(左上至右下)上的数字越大好,副对角线(左下到右上)上的数字越小越好。



多分类变量按照同样的方式去解读。

如果利用混淆矩阵来分析zero classifier在测试集A上的表现:那就是:



显然按照混淆矩阵来评判,zero classifier这样的分类器在测试集A上效果并不理想。

我们利用confusion_matrix看看前面几个分类器表现

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedShuffleSplit cut=StratifiedShuffleSplit(n_splits=1,random_state=42,test_size=0.2)
for a,b in cut.split(X=predi_features,y=diagnosis):    train_x,train_y=predi_features.iloc[a,:],diagnosis[a]    test_x,test_y=predi_features.iloc[b,:],diagnosis[b]

## logistic regression 的confusion matrix

from sklearn.linear_model import LogisticRegression logClf=LogisticRegression() test_prediction_log=logClf.fit(train_x,train_y).predict(test_x) print('confusion matrix for logistical classifier is:') print(confusion_matrix(test_y,test_prediction_log))

## SVM 的confusion matrix

from sklearn.svm import SVC svmClf=SVC(kernel="linear",C=1) test_prediction_svm=svmClf.fit(train_x,train_y).predict(test_x) print('confusion matrix for SVM classifier is:') print(confusion_matrix(test_y,test_prediction_svm))

## 随机森林的confusion matrix

from sklearn.ensemble import RandomForestClassifier rfClf=RandomForestClassifier() test_prediction_rf=rfClf.fit(train_x,train_y).predict(test_x) print('confusion matrix for randomForest classifier is:') print(confusion_matrix(test_y,test_prediction_rf))

## 神经网络的confusion matrix

from sklearn.neural_network import MLPClassifier MLPclf=MLPClassifier(activation="relu",batch_size=80,                  hidden_layer_sizes=(30,30,30,15,55,15,15,30),                  max_iter=2000,random_state=42,learning_rate_init=0.001) test_prediction_MLP=MLPclf.fit(train_x,train_y).predict(test_x) print('confusion matrix for MLP classifier is:') print(confusion_matrix(test_y,test_prediction_MLP))

输出结果

confusion matrix for logistical classifier is: 

[[71    1] 

 [ 6    36]] 

confusion matrix for SVM classifier is: 

[[70     2] 

 [ 7    35]] 

confusion matrix for randomForest classifier is: 

[[72    0] 

 [ 4    38]] 

confusion matrix for MLP classifier is: 

[[70    2]  

[ 6     36]]

可以看到Random Forrest在预测良性和恶性方面都突出一点。

二、precision ,recall and F1-score.

首先看看它们的定义:

precision 表示精准程度,在分类器预测的阳性中,有多大比例是真阳性,recall表示灵敏程度,即分类器能将所有真阳性中多大比例的样本预测出来,一般来讲,两者之间存在一个折衷,精准度非常高,那灵敏度就相对不那么高,反之亦然。而从F1_score的计算公式来看,它把两个指标综合起来,要precision和recall都相对较高时,才能得到一个较高的F1_score.

在scikit-learn中也有现成的函数计算precision_score,recall_scoref1_score.

from sklearn.metrics import precision_score,recall_score,f1_score

## 注意这里要把label标签转换成0,1变量而非字符串

print("precision for SVM is %.2f" % (precision_score(test_y=='M',test_prediction_svm=="M"))) print("recall for SVM is %.2f" %(recall_score(test_y=='M',test_prediction_svm=="M"))) print("F1_score for SVM is %.2f" %(f1_score(test_y=='M',test_prediction_svm=="M"))) print("F1_score for randomforest is %0.2f" %(f1_score(test_y=='M',test_prediction_rf=="M"))) print("F1_score for logisticRegression is %.2f" %(f1_score(test_y=='M',test_prediction_log=="M"))) print("F1_score for MLP is %.2f" %(f1_score(test_y=='M',test_prediction_MLP=="M")))

评价结果

precision for SVM is 0.95 recall for SVM is 0.83 F1_score for SVM is 0.89 F1_score for randomforest is 0.95
F1_score for logisticRegression is 0.91 F1_score for MLP is 0.90


综合得分F1-score来看随机森林要高一些。

三、 ROC curve 和AUC


ROC曲线是根据某一分类器在取不同分类临界值时的分类表现所绘制的曲线,横坐标为假阳性率(False Positive Rate,FPR),纵坐标为真阳性率(True Positive Rate,也就是前面的recall或者灵敏度sensitivity),FPR=1-specificity(特异度,判定阴性的能力),ROC曲线所表示的内容就是sensitivity VS (1-specificity),ROC曲线综合考虑分类器在判别过程中的灵敏度和特异度,也常常用来作为分类器性能对比的评价指标,如下图。



在实际应用中,我们希望特异度和灵敏度两者同时越大越好,即理想情况下,我们希望不管分类阈值多少,灵敏度和特异度都到达1,ROC都落在[0, 1]这个坐标点。但现实情况是,随着灵敏度的增大,特异度会变小,1-特异度变大,出现了图中红色的ROC曲线,但是我们依然希望ROC曲线能逼近坐标图中的左上方,于是我们采用图中红线围成部分的面积AUC(area under curve)来衡量分类器的性能。AUC越大,该分类器的判别价值就越大。(绿色虚线表示完全随机的分类器的ROC曲线)。

scikit-learn里也提供了roc_curve函数,计算不同阈值下TPR和FPR,有了这两个值,我们就可以绘制ROC曲线,下面以random Forest和logistic regression 两个分类器为例绘制ROC曲线。

from sklearn.metrics import roc_curve y_scores=cross_val_predict(rfClf,train_x,train_y,cv=3,method="predict_proba") y_scores_logit=cross_val_predict(logClf,train_x,train_y,cv=3,method="predict_proba") y_scores_rf = y_scores[:, 1] y_scores_log=y_scores_logit[:, 1] fpr_log,tpr_log,t_log=roc_curve(train_y=='M',y_scores_log) fpr,tpr,thres=roc_curve(train_y=='M',y_scores_rf) plt.figure(figsize=(10,5)) plt.plot(fpr,tpr,linewidth=2,c="r",label="randomForest") plt.plot(fpr_log,tpr_log,"b:",label="logisticRegression") plt.plot([0,1],[0,1],"k--") plt.axis([0,1,0,1]) plt.legend(loc="lower right") plt.show()


同样也可以利用函数roc_auc_score计算ROC曲线下面积AUC

from sklearn.metrics import roc_auc_score print('the AUC for logisticRegression is %.4f'%(roc_auc_score(train_y=="M", y_scores_log))) print('the AUC for randomForest is %.4f'%(roc_auc_score(train_y=="M", y_scores_rf)))

the AUC for logisticRegression is 0.9894

 the AUC for randomForest is 0.9842

如果本笔记对您有帮助,请动动手帮我点个赞吧

获取本文数据:下图扫码关注Python爱好者公众号,后台回复 data

Python爱好者社区历史文章大合集

Python爱好者社区历史文章列表(每周append更新一次)

福利:文末扫码立刻关注公众号,“Python爱好者社区”,开始学习Python课程:

关注后在公众号内回复“课程”即可获取:

0.小编的Python入门视频课程!!!

1.崔老师爬虫实战案例免费学习视频。

2.丘老师数据科学入门指导免费学习视频。

3.陈老师数据分析报告制作免费学习视频。

4.玩转大数据分析!Spark2.X+Python 精华实战课程免费学习视频。

5.丘老师Python网络爬虫实战免费学习视频。

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存